/**
 * 
 */
package edu.umd.clip.lm.model.training;

import java.util.HashMap;
import java.util.Map;

import edu.umd.clip.lm.model.data.Context;
import edu.umd.clip.lm.model.data.ContextFuturesPair;
import edu.umd.clip.lm.model.data.TrainingDataBlock;
import edu.umd.clip.lm.model.data.TupleCountPair;
import edu.umd.clip.lm.util.CompactReadOnlyLong2ObjectHashMap;
import edu.umd.clip.lm.util.ConstCountDistribution;
import edu.umd.clip.lm.util.Long2IntMap;

/**
 * Accumulates count statistics relating the value of one context variable
 * (as reported by {@code ContextVariable.getIntValue(Context)}) to the future
 * tuples ("words") observed in training data blocks.
 * <p>
 * Which statistics are collected is fixed at construction time through the four
 * {@code has*} flags. Statistics are stored in compact read-only form between
 * calls and are rebuilt on every {@link #processBlock} call. At most
 * {@link #MAX_BLOCKS} blocks are taken into account; later blocks are ignored.
 */
public class ContextVariableStats {
	/**
	 * Total count per future tuple, summed over all processed blocks.
	 * Stays {@code null} until a block is processed with {@link #hasWordCounts} set.
	 */
	public ConstCountDistribution wordCounts;
	/**
	 * Total future count per context-variable value, summed over all processed blocks.
	 * Stays {@code null} until a block is processed with {@link #hasContextCounts} set.
	 */
	public ConstCountDistribution contextCounts;
	/**
	 * For each future tuple (key), the distribution of context-variable values it was seen with.
	 * Stays {@code null} until a block is processed with {@link #hasW2X} set.
	 */
	public CompactReadOnlyLong2ObjectHashMap<ConstCountDistribution> w2xDistributions;
	/**
	 * For each context-variable value (key), the distribution of future tuples seen with it.
	 * Stays {@code null} until a block is processed with {@link #hasX2W} set.
	 */
	public CompactReadOnlyLong2ObjectHashMap<ConstCountDistribution> x2wDistributions;
	
	public final boolean hasWordCounts;
	public final boolean hasContextCounts;
	public final boolean hasW2X;
	public final boolean hasX2W;
	
	/** Maximum number of blocks folded into the statistics; further blocks are skipped. */
	public final static int MAX_BLOCKS = 3;
	/** Number of blocks folded in so far; never exceeds {@link #MAX_BLOCKS}. */
	public int processedBlocks;

	/**
	 * Creates an empty statistics holder.
	 *
	 * @param hasWordCounts whether {@link #wordCounts} should be maintained
	 * @param hasContextCounts whether {@link #contextCounts} should be maintained
	 * @param hasW2X whether {@link #w2xDistributions} should be maintained
	 * @param hasX2W whether {@link #x2wDistributions} should be maintained
	 */
	public ContextVariableStats(boolean hasWordCounts,
			boolean hasContextCounts, boolean hasW2X, boolean hasX2W) {
		this.hasWordCounts = hasWordCounts;
		this.hasContextCounts = hasContextCounts;
		this.hasW2X = hasW2X;
		this.hasX2W = hasX2W;
		
		this.processedBlocks = 0;
	}
	
	/**
	 * Folds one training data block into every enabled statistic.
	 * <p>
	 * Does nothing once {@link #MAX_BLOCKS} blocks have been processed.
	 * Otherwise {@link #processedBlocks} is incremented, even if no statistic is enabled.
	 *
	 * @param ctxVar the context variable whose value is extracted from each context
	 * @param block the block of context/futures pairs to accumulate
	 */
	public void processBlock(ContextVariable ctxVar, TrainingDataBlock block) {
		if (processedBlocks >= MAX_BLOCKS) return;
		++processedBlocks;
		
		/* 
		 * It seems wasteful to convert between the compact read-only and the modifiable
		 * representations for every block, but only the first few levels in the tree will
		 * have more than one block. Additionally, the number of blocks considered is limited
		 * because, after some point, extra data should not change the scores much.
		 */
		if (this.hasWordCounts) {
			accumulateWordCounts(block);
		}
		if (this.hasContextCounts) {
			accumulateContextCounts(ctxVar, block);
		}
		if (this.hasX2W) {
			accumulateX2W(ctxVar, block);
		}
		if (this.hasW2X) {
			accumulateW2X(ctxVar, block);
		}
	}

	/** Adds the count of every future tuple in the block to {@link #wordCounts}. */
	private void accumulateWordCounts(TrainingDataBlock block) {
		Long2IntMap counts = this.wordCounts == null ? new Long2IntMap() : this.wordCounts.toLong2IntMap();
		
		for(final ContextFuturesPair pair : block) {
			for(TupleCountPair tc : pair.getFutures()) {
				counts.addAndGet(tc.tuple, tc.count);
			}
		}
		
		this.wordCounts = new ConstCountDistribution(counts);
	}

	/** Adds, per pair, the total future count to {@link #contextCounts} under the pair's variable value. */
	private void accumulateContextCounts(ContextVariable ctxVar, TrainingDataBlock block) {
		Long2IntMap counts = this.contextCounts == null ? new Long2IntMap() : this.contextCounts.toLong2IntMap();

		for(final ContextFuturesPair pair : block) {
			final Context context = pair.getContext();
			
			long total = 0;
			for(TupleCountPair tc : pair.getFutures()) {
				total += tc.count;
			}
			// The sum is kept in a long but narrowed to int here; a total beyond
			// the int range would be silently truncated by the cast.
			counts.addAndGet(ctxVar.getIntValue(context), (int) total);
		}
		
		this.contextCounts = new ConstCountDistribution(counts);
	}

	/** Adds the block's future tuples to {@link #x2wDistributions}, grouped by variable value. */
	private void accumulateX2W(ContextVariable ctxVar, TrainingDataBlock block) {
		HashMap<Long, Long2IntMap> x2w = unpack(this.x2wDistributions);
		
		for(final ContextFuturesPair pair : block) {
			final Context context = pair.getContext();
			
			Long varValue = Long.valueOf(ctxVar.getIntValue(context));
			Long2IntMap dist = x2w.get(varValue);
			if (dist == null) {
				dist = new Long2IntMap();
				x2w.put(varValue, dist);
			}
			for(TupleCountPair tc : pair.getFutures()) {
				dist.addAndGet(tc.tuple, tc.count);
			}
		}
		
		this.x2wDistributions = pack(x2w);
	}

	/** Adds the block's variable values to {@link #w2xDistributions}, grouped by future tuple. */
	private void accumulateW2X(ContextVariable ctxVar, TrainingDataBlock block) {
		HashMap<Long, Long2IntMap> w2x = unpack(this.w2xDistributions);
		
		for(final ContextFuturesPair pair : block) {
			final Context context = pair.getContext();
			final int varValue = ctxVar.getIntValue(context);
			
			for(TupleCountPair tc : pair.getFutures()) {
				Long word = tc.tuple;
				Long2IntMap dist = w2x.get(word);
				if (dist == null) {
					dist = new Long2IntMap();
					w2x.put(word, dist);
				}
				dist.addAndGet(varValue, tc.count);
			}
		}
		
		this.w2xDistributions = pack(w2x);
	}

	/**
	 * Expands a compact map of distributions into a modifiable map of modifiable counts.
	 *
	 * @param compact the compact map, or {@code null} if nothing has been accumulated yet
	 * @return a fresh modifiable map; empty when {@code compact} is {@code null}
	 */
	private static HashMap<Long, Long2IntMap> unpack(
			CompactReadOnlyLong2ObjectHashMap<ConstCountDistribution> compact) {
		if (compact == null) {
			return new HashMap<Long, Long2IntMap>();
		}
		// keys() and values() are consumed as parallel arrays: entry i of one
		// is paired with entry i of the other.
		long keys[] = compact.keys();
		ConstCountDistribution dists[] = compact.values(new ConstCountDistribution[compact.size()]);
		HashMap<Long, Long2IntMap> expanded = new HashMap<Long, Long2IntMap>(keys.length);
		for(int i=0; i<keys.length; ++i) {
			expanded.put(keys[i], dists[i].toLong2IntMap());
		}
		return expanded;
	}

	/**
	 * Freezes a modifiable map of counts into the compact read-only representation.
	 *
	 * @param counts the modifiable per-key counts
	 * @return a compact map holding one {@link ConstCountDistribution} per key
	 */
	private static CompactReadOnlyLong2ObjectHashMap<ConstCountDistribution> pack(
			Map<Long, Long2IntMap> counts) {
		long keys[] = new long[counts.size()];
		ConstCountDistribution dists[] = new ConstCountDistribution[counts.size()];
		int pos = 0;
		for(Map.Entry<Long, Long2IntMap> e : counts.entrySet()) {
			keys[pos] = e.getKey().longValue();
			dists[pos] = new ConstCountDistribution(e.getValue());
			++pos;
		}
		return new CompactReadOnlyLong2ObjectHashMap<ConstCountDistribution>(keys, dists);
	}
}